# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd

from .rsp_rav import rsp_rav
from .rsp_rrv import rsp_rrv


def rsp_intervalrelated(data, sampling_rate=1000):
    """**State-related Respiration Indices**

    Performs RSP analysis on longer periods of data (typically > 10 seconds), such as resting-state
    data. It returns a dataframe containing features and characteristics of respiration in that
    interval.

    Parameters
    ----------
    data : DataFrame or dict
        A DataFrame containing the different processed signal(s) as different columns, typically
        generated by :func:`.rsp_process` or :func:`.bio_process`. Can also take a dict containing
        sets of separately processed DataFrames (each must contain a ``"Label"`` column, as
        produced by :func:`.epochs_create`).
    sampling_rate : int
        The sampling frequency of the signal (in Hz, i.e., samples/second).

    Returns
    -------
    DataFrame
        A dataframe containing the analyzed RSP features (one row per interval).
        The analyzed features consist of the following:

        .. codebookadd::
            RSP_Rate_Mean|The mean respiratory rate.
            RSP_Amplitude_Mean|The mean respiratory amplitude.
            RSP_Phase_Duration_Inspiration|The average inspiration duration.
            RSP_Phase_Duration_Expiration|The average expiration duration.
            RSP_Phase_Duration_Ratio|The inspiration-to-expiration time ratio (I/E).

        * ``"RSP_RRV"``: the different respiratory rate variability metrics.
          See :func:`.rsp_rrv` docstrings for details.
        * ``"RSP_RAV"``: the different respiratory amplitude variability metrics.
          See :func:`.rsp_rav` docstrings for details.
        * ``"RSP_RVT"``, ``"RSP_Symmetry_PeakTrough"``, ``"RSP_Symmetry_RiseDecay"``: the mean of
          the corresponding columns, when they are present in the input.

    Raises
    ------
    TypeError
        If ``data`` is neither a DataFrame nor a dict.

    See Also
    --------
    bio_process, rsp_eventrelated, rsp_rrv, rsp_rav

    Examples
    --------
    .. ipython:: python

      import neurokit2 as nk

      # Download data
      data = nk.data("bio_resting_5min_100hz")

      # Process the data
      df, info = nk.rsp_process(data["RSP"], sampling_rate=100)

      # Single dataframe is passed
      nk.rsp_intervalrelated(df, sampling_rate=100)

      epochs = nk.epochs_create(df, events=[0, 15000], sampling_rate=100, epochs_end=150)
      nk.rsp_intervalrelated(epochs, sampling_rate=100)

    """
    # Single interval: compute the features into a fresh dict (passed explicitly so that no
    # state can be shared between calls), then turn it into a one-row DataFrame.
    if isinstance(data, pd.DataFrame):
        intervals = _rsp_intervalrelated_features(data, sampling_rate, {})
        intervals = pd.DataFrame.from_dict(intervals, orient="index").T

    # Multiple intervals (e.g. epochs): one row per key, labelled with the epoch's label.
    elif isinstance(data, dict):
        intervals = {}
        for index, epoch in data.items():
            features = {"Label": epoch["Label"].iloc[0]}
            intervals[index] = _rsp_intervalrelated_features(epoch, sampling_rate, features)

        intervals = pd.DataFrame.from_dict(intervals, orient="index")

    else:
        raise TypeError(
            "NeuroKit error: rsp_intervalrelated(): Please specify an input "
            "that is of the correct form i.e., either a dataframe or a dictionary."
        )

    return intervals


# =============================================================================
# Internals
# =============================================================================


def _rsp_intervalrelated_features(data, sampling_rate, output={}):
    # Sanitize input
    colnames = data.columns.values

    if "RSP_Rate" in colnames:
        output["RSP_Rate_Mean"] = np.nanmean(data["RSP_Rate"].values)
        rrv = rsp_rrv(data, sampling_rate=sampling_rate)
        output.update(rrv.to_dict(orient="records")[0])

    if "RSP_Amplitude" in colnames:
        rav = rsp_rav(data["RSP_Amplitude"].values, peaks=data)
        output.update(rav.to_dict(orient="records")[0])

    if "RSP_RVT" in colnames:
        output["RSP_RVT"] = np.nanmean(data["RSP_RVT"].values)

    if "RSP_Symmetry_PeakTrough" in colnames:
        output["RSP_Symmetry_PeakTrough"] = np.nanmean(
            data["RSP_Symmetry_PeakTrough"].values
        )
        output["RSP_Symmetry_RiseDecay"] = np.nanmean(
            data["RSP_Symmetry_RiseDecay"].values
        )

    if "RSP_Phase" in colnames:
        # Extract inspiration durations
        insp_phases = data[data["RSP_Phase"] == 1]
        insp_start = insp_phases.index[insp_phases["RSP_Phase_Completion"] == 0]
        insp_end = insp_phases.index[insp_phases["RSP_Phase_Completion"] == 1]

        # Check that start of phase is before end of phase
        if insp_start[0] > insp_end[0]:
            insp_end = insp_end[1:]

        # Check for unequal lengths
        diff = abs(len(insp_start) - len(insp_end))
        if len(insp_start) > len(insp_end):
            insp_start = insp_start[
                : len(insp_start) - diff
            ]  # remove extra start points
        elif len(insp_end) > len(insp_start):
            insp_end = insp_end[: len(insp_end) - diff]  # remove extra end points

        insp_times = np.array(insp_end - insp_start) / sampling_rate

        # Extract expiration durations
        exp_phases = data[data["RSP_Phase"] == 0]
        exp_start = exp_phases.index[exp_phases["RSP_Phase_Completion"] == 0]
        exp_end = exp_phases.index[exp_phases["RSP_Phase_Completion"] == 1]

        # Check that start of phase is before end of phase
        if exp_start[0] > exp_end[0]:
            exp_end = exp_end[1:]

        # Check for unequal lengths
        diff = abs(len(exp_start) - len(exp_end))
        if len(exp_start) > len(exp_end):
            exp_start = exp_start[: len(exp_start) - diff]  # remove extra start points
        elif len(exp_end) > len(exp_start):
            exp_end = exp_end[: len(exp_end) - diff]  # remove extra end points

        exp_times = np.array(exp_end - exp_start) / sampling_rate

        output["RSP_Phase_Duration_Inspiration"] = np.mean(insp_times)
        output["RSP_Phase_Duration_Expiration"] = np.mean(exp_times)
        output["RSP_Phase_Duration_Ratio"] = (
            output["RSP_Phase_Duration_Inspiration"]
            / output["RSP_Phase_Duration_Expiration"]
        )

    return output
